##################################################
#
# pentestlib
#
# Collection of functions and helpers useful for penetration testing
#


##################################################
#
# imports
#

import datetime
import inspect
import itertools
import random
import re
import sys

try :
    import netaddr
    import requests
    import termcolor
    requests.packages.urllib3.disable_warnings()

except ImportError as ie:
    print("[-] Failed on import: %s" % ie)
    exit(1)


def info(msg):
    """Print *msg* behind a bold blue ``[*]`` tag (informational message)."""
    tag = termcolor.colored("[*] ", "blue", attrs=["bold"])
    print(tag + msg)
def ok(msg):
    """Print *msg* behind a bold green ``[+]`` tag (success message)."""
    tag = termcolor.colored("[+] ", "green", attrs=["bold"])
    print(tag + msg)
def warn(msg):
    """Print *msg* behind a bold yellow ``[!]`` tag (warning message)."""
    # termcolor has no "orange" colour: asking for it raises KeyError, so the
    # original helper could never print anything. Yellow is the closest
    # colour termcolor provides.
    tag = termcolor.colored("[!] ", "yellow", attrs=["bold"])
    print(tag + msg)
def err(msg):
    """Print *msg* behind a bold red ``[!]`` tag (error message)."""
    tag = termcolor.colored("[!] ", "red", attrs=["bold"])
    print(tag + msg)


def now():
    """Return the current local time formatted as ``DD/MM/YY HH:MM:SS`` for logs."""
    current = datetime.datetime.now()
    return current.strftime("%d/%m/%y %H:%M:%S")



##################################################
#
#  Binary stuff
#

def hexdump(source, length=0x10, separator=".", base=0x000000):
    """
    Render a bytes-like object the way ``xxd`` does.

    Each output row holds the address (``base`` plus the row offset), the
    hexadecimal value of up to ``length`` bytes, and their printable ASCII
    form; bytes outside 0x20..0x7e are shown as ``separator``.

    >>> print(hexdump(b'\\x0e\\x82\\t\\x05:\\xd6\\x8c\\xf1\\xc6L\\x94\\xb3PN>\\xfb'))
    0x00000000    0e 82 09 05 3a d6 8c f1 c6 4c 94 b3 50 4e 3e fb     ....:....L..PN>.
    >>> print(hexdump(b'\\x0e\\x82\\t\\x05:\\xd6\\x8c\\xf1\\xc6L\\x94\\xb3PN>\\xfb', 8))
    0x00000000    0e 82 09 05 3a d6 8c f1     ....:...
    0x00000008    c6 4c 94 b3 50 4e 3e fb     .L..PN>.
    """
    address_width = 10          # "0x" followed by eight hex digits
    hex_width = 3 * length      # two digits plus one space per byte
    rows = []
    for offset in range(0, len(source), length):
        row_bytes = bytearray(source[offset:offset + length])
        hex_column = " ".join("{:02x}".format(value) for value in row_bytes)
        printable = []
        for value in row_bytes:
            if 0x20 <= value < 0x7F:
                printable.append(chr(value))
            else:
                printable.append(separator)
        rows.append("{addr:#0{aw}x}    {data:<{dw}}    {text}".format(
            addr=base + offset,
            aw=address_width,
            data=hex_column,
            dw=hex_width,
            text="".join(printable),
        ))
    return "\n".join(rows)




##################################################
#
#  Web stuff
#

def called():
    """@return name of the function that invoked called(), as a string"""
    # Record 0 is called() itself; record 1 is whoever invoked it.
    invoking_frame = inspect.stack()[1]
    return invoking_frame[3]

def called_line():
    """@return element 4 of the frame record of the function invoking called_line()

    NOTE(review): despite the name, index 4 of an inspect frame record is the
    code context (a list of source lines, or None when the source is not
    available); the line number lives at index 2. Behaviour is left as is -
    confirm what callers expect before changing it.
    """
    invoking_frame = inspect.stack()[1]
    return invoking_frame[4]

def caller_name():
    """@return name of the caller of the function invoking caller_name(), as a string"""
    # Record 0 is caller_name(), record 1 the function using it, record 2 the
    # function that called that one.
    grandparent_frame = inspect.stack()[2]
    return grandparent_frame[3]

def caller_line():
    """@return element 4 of the frame record two levels up the call stack

    NOTE(review): despite the name, index 4 of an inspect frame record is the
    code context (a list of source lines, or None when the source is not
    available); the line number lives at index 2. Behaviour is left as is -
    confirm what callers expect before changing it.
    """
    grandparent_frame = inspect.stack()[2]
    return grandparent_frame[4]


def GET(url, headers=None, proxies=None):
    """Send a GET request to *url* with a spoofed MSIE User-Agent.

    @param url: target URL
    @param headers: optional dict of extra request headers; any User-Agent
                    entry is overridden
    @param proxies: optional proxies mapping handed to requests
    @return the requests response object

    TLS certificate verification is disabled (verify=False).
    """
    # Work on a copy: the original wrote into a shared default dict (and into
    # the caller's own dict), so headers leaked from one call to the next.
    request_headers = dict(headers or {})
    request_headers["User-Agent"] = "Mozilla/5.0 (compatible; MSIE 7.0; Windows NT 6.0; fr-FR)"
    return requests.get(url, proxies=dict(proxies or {}), headers=request_headers, verify=False)


def POST(url, data=None, headers=None, proxies=None):
    """Send a POST request to *url* with a spoofed MSIE User-Agent.

    @param url: target URL
    @param data: optional request body handed to requests (empty dict when omitted)
    @param headers: optional dict of extra request headers; any User-Agent
                    entry is overridden
    @param proxies: optional proxies mapping handed to requests
    @return the requests response object

    TLS certificate verification is disabled (verify=False).
    """
    # Work on a copy: the original wrote into a shared default dict (and into
    # the caller's own dict), so headers leaked from one call to the next.
    request_headers = dict(headers or {})
    request_headers["User-Agent"] = "Mozilla/5.0 (compatible; MSIE 7.0; Windows NT 6.0; fr-FR)"
    payload = {} if data is None else data
    return requests.post(url, data=payload, proxies=dict(proxies or {}),
                         headers=request_headers, verify=False)


def TRACE(url, fwd_until=1, headers=None, proxies=None):
    """
    Probe proxy presence through TRACE method
    ref: http://www.w3.org/Protocols/rfc2616/rfc2616-sec9.html

    One TRACE request is sent per Max-Forwards value in 0 .. fwd_until-1,
    all through the same session.

    @param url: target URL
    @param fwd_until: number of probes to send (must be >= 1)
    @param headers: optional dict of extra request headers; User-Agent and
                    Max-Forwards entries are overridden
    @param proxies: optional proxies mapping handed to requests
    @return list of response header mappings, one per probe, in sending order
    @raise ValueError: when fwd_until is lower than 1
    """
    if fwd_until < 1:
        raise ValueError("Max-Forward value is too low")

    # Copy so that neither a shared default nor the caller's dict is mutated.
    base_headers = dict(headers or {})
    base_headers["User-Agent"] = "Mozilla/5.0 (compatible; MSIE 7.0; Windows NT 6.0; fr-FR)"
    proxy_map = dict(proxies or {})

    def do_trace(sess, max_forward):
        probe_headers = dict(base_headers)
        # Header values must be strings: requests rejects an int value.
        probe_headers["Max-Forwards"] = str(max_forward)
        req = requests.Request("TRACE", url, headers=probe_headers).prepare()
        return sess.send(req, verify=False, proxies=proxy_map).headers

    session = requests.Session()
    try:
        return [do_trace(session, hops) for hops in range(fwd_until)]
    finally:
        # Release pooled connections even when a probe fails.
        session.close()


def html_encode(string, format="html", full=False):
    """Encode a string in different format useful for WebApp PT
    Supported formats are: html/decimal/hexa (case-insensitive)

    - html:    percent-encoding, ``%xx``
    - decimal: decimal character references, ``&#NN;``
    - hexa:    hexadecimal character references, ``&#xNN;``

    Underscores are always kept as is. ASCII letters and digits are kept
    unless ``full`` is true.

    @raise ValueError: when format is not one of the supported names

    >>> html_encode("<hello> <world>", full=True)
    '%3c%68%65%6c%6c%6f%3e%20%3c%77%6f%72%6c%64%3e'
    >>> html_encode("<hello>", full=True, format="hexa")
    '&#x3c;&#x68;&#x65;&#x6c;&#x6c;&#x6f;&#x3e;'
    >>> html_encode("\\n", full=True)
    '%0a'
    >>> html_encode("<hello>", full=True, format="blah")
    Traceback (most recent call last):
    ...
    ValueError: Unknown format
    """
    # "%%%02x": percent escapes are always two hex digits, otherwise control
    # characters such as "\n" would come out as "%a" and could not be decoded.
    templates = {"html": "%%%02x", "decimal": "&#%d;", "hexa": "&#x%x;"}

    # Normalise once: the original validated the lowercased name but then
    # compared the raw one, so format="HTML" silently dropped characters.
    template = templates.get(format.lower())
    if template is None:
        raise ValueError("Unknown format")

    pieces = []
    for c in string:
        is_alnum = '0' <= c <= '9' or 'A' <= c <= 'Z' or 'a' <= c <= 'z'
        if c == '_' or (not full and is_alnum):
            pieces.append(c)
        else:
            pieces.append(template % ord(c))
    return "".join(pieces)


def html_decode(string, format="html"):
    """Decode a string encoded by @@html_encode
    Supported formats are: html/decimal/hexa (case-insensitive)

    - html:    ``%xx`` (exactly two hexadecimal digits)
    - decimal: ``&#NN;`` (any number of digits, trailing ``;`` optional)
    - hexa:    ``&#xNN;`` (any number of digits, trailing ``;`` optional)

    Text that is not a well-formed escape sequence is returned unchanged.

    @raise ValueError: when format is not one of the supported names

    >>> html_decode("<hello>", format="blah")
    Traceback (most recent call last):
    ...
    ValueError: Unknown format
    >>> html_decode('&#x3c;&#x68;&#x65;&#x6c;&#x6c;&#x6f;&#x3e;', format="hexa")
    '<hello>'
    >>> html_decode('&#60;&#104;&#105;&#62;', format="decimal")
    '<hi>'
    """
    # (pattern, radix) per format. Regular expressions replace the original
    # fixed two-character slicing, which broke on decimal values >= 100, on
    # one-digit hex references, and required stripping every ';' from the
    # input, including semicolons that were ordinary text.
    decoders = {
        "html": (r"%([0-9a-fA-F]{2})", 16),
        "decimal": (r"&#([0-9]+);?", 10),
        "hexa": (r"&#[xX]([0-9a-fA-F]+);?", 16),
    }
    decoder = decoders.get(format.lower())
    if decoder is None:
        raise ValueError("Unknown format")
    pattern, radix = decoder

    def to_char(match):
        try:
            return chr(int(match.group(1), radix))
        except (ValueError, OverflowError):
            # Number outside the range chr() accepts: keep the text untouched.
            return match.group(0)

    return re.sub(pattern, to_char, string)


def mysql_char(s):
    """Build a single MySQL CHAR() call listing the code point of every character of s
    >>> mysql_char("mysql")
    'CHAR(109,121,115,113,108)'
    """
    code_points = []
    for character in s:
        code_points.append("%d" % ord(character))
    arguments = ",".join(code_points)
    return "CHAR(%s)" % arguments


def mssql_char(s):
    """Build an MS SQL expression concatenating one CHAR() call per character of s
    >>> mssql_char("mssql")
    'CHAR(109)+CHAR(115)+CHAR(115)+CHAR(113)+CHAR(108)'
    """
    calls = []
    for character in s:
        calls.append("CHAR(%d)" % ord(character))
    return "+".join(calls)


class HTTPReq:
    """
    Raw HTTP request parser (Raw -> Object).
    Can also serialise the request back to text (Object -> Raw).

    Attributes set by the constructor:
      method, path, version -- the three fields of the request line
      headers               -- dict mapping header name to value (a repeated
                               header keeps only its last value)
      data                  -- request body ('' when there is none)
    """
    CRLF = "\r\n"
    SEP  = ": "

    def __init__(self, req):
        """Parse the raw request text *req* (lines separated by CRLF).

        Raises ValueError when the request line does not hold exactly three
        space-separated fields, IndexError when *req* holds no line at all.
        """
        self.method = 'GET'
        self.path = '/'
        self.version = 'HTTP/1.1'
        self.data = ''
        self.headers = {}

        # The first blank line separates the header section from the body.
        # Splitting there (instead of dropping every empty line) keeps a
        # multi-line body intact and stops a body-less POST from swallowing
        # its last header as data.
        head, blank_line, body = req.lstrip(self.CRLF).partition(self.CRLF * 2)
        lines = [line for line in head.split(self.CRLF) if line]
        self.method, self.path, self.version = lines.pop(0).split(" ")
        self.data = body

        # Backward compatibility: a POST pasted without the blank line is
        # still accepted when its last line clearly is not a header.
        if (not blank_line and self.method == 'POST' and lines
                and not self.is_header(lines[-1])):
            self.data = lines.pop()

        for line in lines:
            # A line without the separator is stored as a value-less header.
            name, _, value = line.partition(self.SEP)
            self.headers[name] = value

    def is_header(self, chunk):
        """Tell whether *chunk* contains the header name/value separator."""
        return self.SEP in chunk

    def __str__(self):
        """Serialise the request: request line, headers, blank line, body."""
        lines = ["%s %s %s" % (self.method, self.path, self.version)]
        lines.extend("%s%s%s" % (h, self.SEP, v) for (h, v) in self.headers.items())
        # The blank line terminating the header section is mandatory; without
        # it the body would be read as one more header line.
        return self.CRLF.join(lines) + self.CRLF * 2 + self.data



##################################################
#
#   Some Unicode converting functions
#

def ucs_string(chaine, format=2, to_html=False):
    """Encode every character of a string as a fixed-width UTF-8 style sequence

    Each character is spread over exactly ``format`` bytes (2, 3 or 4), even
    when fewer bytes would be enough: the payload is left-padded with zero
    bits, which yields over-long sequences for small code points.

    @param chaine: string to convert
    @param format: number of bytes per character: 2, 3 or 4
    @param to_html: when true, return one '%xx%xx...' string instead of a list
    @return list of two-digit lowercase hex strings, or the percent-encoded
            string when to_html is true; -1 when format is not 2, 3 or 4
    @raise ValueError: when a character does not fit in ``format`` bytes

    >>> ucs_string("test")
    ['c1', 'b4', 'c1', 'a5', 'c1', 'b3', 'c1', 'b4']
    >>> ucs_string("test", to_html=True)
    '%c1%b4%c1%a5%c1%b3%c1%b4'
    >>> ucs_string("A", format=3)
    ['e0', '81', '81']
    """
    if format not in range(2, 5):
        return -1

    width = int(format)
    # Marker bits of the lead byte: 110xxxxx, 1110xxxx, 11110xxx.
    lead_marker = (0xC0, 0xE0, 0xF0)[width - 2]
    # Payload capacity: 11, 16 or 21 bits for 2, 3 or 4 bytes.
    max_code_point = (1 << (5 * width + 1)) - 1

    def encode(c):
        code_point = ord(c)
        if code_point > max_code_point:
            raise ValueError("U+%04X does not fit in a %d-byte sequence"
                             % (code_point, width))
        # The lead byte carries the highest payload bits; every continuation
        # byte carries the next 6 bits. (The original only ever split off the
        # lowest 6 bits, so anything wider than 12 bits came out corrupted.)
        encoded = [lead_marker | (code_point >> (6 * (width - 1)))]
        for position in range(width - 2, -1, -1):
            encoded.append(0x80 | ((code_point >> (6 * position)) & 0x3f))
        return ['%02x' % value for value in encoded]

    res = []
    for c in chaine:
        res.extend(encode(c))

    if not to_html:
        return res
    # Prefix each byte separately so that an empty input gives '' (not '%').
    return "".join("%" + value for value in res)


##################################################
#
# Network functions
#

def expand_cidr(plage, out="/dev/stdout", mode="a"):
    """
    Expands a CIDR range and write result into output file, one address per line.

    @param plage: IPv4 range in CIDR notation, e.g. "192.168.0.0/30"
    @param out: path of the file to write to
    @param mode: mode used to open the output file
    @return 0 on success, -1 when plage does not look like an IPv4 CIDR range

    Ranges that pass the syntax check but are rejected by netaddr raise
    netaddr's own exception.

    >>> expand_cidr("192.168.0.0/30", "/dev/null", "w")
    0
    >>> expand_cidr("783.128.0.0/13", "/dev/null", "w")  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    AddrFormatError: invalid IPNetwork 783.128.0.0/13
    """
    # Dots are escaped: unescaped, '.' matched any character, so strings such
    # as "1a2b3c4/5" passed the check.
    if not re.match(r"([0-9]{1,3}\.){3}[0-9]{1,3}/[0-9]{0,2}", plage):
        return -1

    with open(out, mode) as f:
        f.writelines("%s\n" % str(ip) for ip in netaddr.iter_unique_ips(plage))
    return 0


##################################################
#
#   Crypto functions
#

def levenshtein(s, t):
    """
    Return the Levenshtein (edit) distance between the strings s and t

    >>> levenshtein("tset", "test")
    2
    """
    # A leading pad character makes index 0 stand for the empty prefix.
    padded_s = ' ' + s
    padded_t = ' ' + t
    rows = len(padded_s)
    cols = len(padded_t)

    # grid[r][c] is the distance between the first r characters of s and the
    # first c characters of t.
    grid = [[0] * cols for _ in range(rows)]
    for r in range(rows):
        grid[r][0] = r
    for c in range(cols):
        grid[0][c] = c

    for c in range(1, cols):
        for r in range(1, rows):
            if padded_s[r] == padded_t[c]:
                grid[r][c] = grid[r - 1][c - 1]
                continue
            deletion = grid[r - 1][c] + 1
            insertion = grid[r][c - 1] + 1
            substitution = grid[r - 1][c - 1] + 1
            grid[r][c] = min(deletion, insertion, substitution)

    return grid[rows - 1][cols - 1]



def caesar(plaintext, shift):
    """
    Rotate the ASCII letters of plaintext by shift positions, keeping case

    Characters whose lowercase form is not in a-z are copied unchanged.

    >>> caesar("Hello World", 13) == 'Uryyb Jbeyq'
    True
    >>> caesar("Hello World", 13) == 'Uryyb JbeZq'
    False
    >>> caesar("Hello World", 26) == "Hello World"
    True
    """
    alphabet = [chr(code) for code in range(ord('a'), ord('z') + 1)]
    size = len(alphabet)
    substitution = {}
    for position, letter in enumerate(alphabet):
        substitution[letter] = alphabet[(position + shift) % size]

    output = []
    for character in plaintext:
        lowered = character.lower()
        if lowered not in alphabet:
            output.append(character)
        elif character.isupper():
            output.append(substitution[lowered].upper())
        else:
            output.append(substitution[lowered])
    return "".join(output)


def xor(data, key):
    """ XOR every character of data with the key, repeating the key as needed
    >>> xor('my string', 'AAA')
    ',8a253(/&'
    >>> xor(',8a253(/&', 'AAA')
    'my string'
    """
    key_stream = itertools.cycle(key)
    mixed = []
    for data_char, key_char in zip(data, key_stream):
        mixed.append(chr(ord(data_char) ^ ord(key_char)))
    return ''.join(mixed)


##################################################
#
# SQL
#

def sqlparse(req, encode=None):
    """
    Parse a SQL request and return its clauses as a dict()

    The statement is lowercased and split on whitespace. The text following
    each of the keywords select / limit / from / where (up to the next
    keyword) is stored under that keyword. Only the first occurrence of a
    keyword is kept.

    @param req: SQL statement
    @param encode: "mysql"/"MySQL" or "mssql"/"SQLServer" to replace every
                   single-quoted literal by the matching CHAR() expression
                   before parsing; any other value leaves literals untouched
    @return dict mapping each keyword found to the text of its clause

    >>> sqlparse("SELECT * FROM information_schema.columns WHERE foo = 1") == \\
    ...     {'where': 'foo = 1', 'from': 'information_schema.columns', 'select': '*'}
    True
    >>> sqlparse("SELECT a FROM t WHERE b = 'Bob'", encode="mysql")['where']
    'b = char(66,111,98)'
    """
    keywords = ('select', 'limit', 'from', 'where')

    # The original referenced ascii2mysql / ascii2mssql, which this module
    # does not define; the CHAR() builders here are mysql_char / mssql_char.
    if encode in ("mysql", "MySQL"):
        convert_func = mysql_char
    elif encode in ("mssql", "SQLServer"):
        convert_func = mssql_char
    else:
        convert_func = None

    if convert_func is not None:
        def encode_literal(match):
            content = match.group(1)
            # An empty literal has no CHAR() equivalent: keep it as written.
            return convert_func(content) if content else match.group(0)

        # Literals are encoded before lowercasing so their case survives, and
        # through a callback so their content is never read as a regex or as
        # a replacement template.
        req = re.sub(r"'([^']*)'", encode_literal, req)

    sql_blocks = {}
    current = None
    for word in req.lower().split():
        if word in keywords:
            # Whole-word match only: a column named "selected" is not a
            # SELECT clause. A repeated keyword closes the current clause
            # without overwriting the first occurrence.
            if word in sql_blocks:
                current = None
            else:
                current = word
                sql_blocks[current] = []
        elif current is not None:
            sql_blocks[current].append(word)

    return dict((name, " ".join(words)) for name, words in sql_blocks.items())



##################################################
#
#   Main is only used for unit testing
#

if __name__ == "__main__":
    # Executed as a script: run the doctests embedded in the docstrings above.
    from doctest import testmod
    testmod()
